# -*- coding: utf-8 -*-

import colorsys
import socket
import threading
import struct
import os
import time
import numpy
import cv2
from timeit import default_timer as timer

import numpy as np
from PIL import Image, ImageFont, ImageDraw
from keras import backend as K
from keras.layers import Input
from keras.models import load_model
from keras.utils import multi_gpu_model

from vision.yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
from vision.yolo3.utils import letterbox_image
from settings import *
import sys

# NOTE(review): this runs after the project imports above (vision.yolo3.*, settings),
# so it does not affect how those are resolved; only imports executed later see "..".
sys.path.append("..")


def online_client():
    """Connect to the remote camera server configured in settings.py and start streaming.

    Socket-level connection failures (refused, unreachable, timeout) are
    reported instead of raised. Any other error propagates, so genuine bugs
    are not hidden behind the "server shut off" message.
    """
    print("Establishing Connection...")
    cam = webCamConnect()
    print("Resolution: %d x %d (Edit settings.py to modify)" % (cam.resolution[0], cam.resolution[1]))
    print("Remote IP and Port: %s:%d" % (cam.remoteAddress[0], cam.remoteAddress[1]))
    try:
        cam.connect()
    except OSError:  # socket errors (ConnectionRefusedError, timeouts, ...) are OSError subclasses
        print("Remote server seems to be shut off!")
        return
    # Starts the display thread (and the snapshot thread if cam.interval != 0).
    cam.getData(cam.interval)


def rtsp(rtspAdress):
    """Run live detection on the stream at *rtspAdress* by delegating to detect_camera."""
    detect_camera(cam=rtspAdress)


def detect_video(video_path, output_path=""):
    """Run YOLO on every frame of a video, show it live and optionally save it.

    Args:
        video_path: anything ``cv2.VideoCapture`` accepts (file path, URL, ...).
        output_path: if non-empty, annotated frames are also written there
            using the XVID codec with the source FPS and frame size.
            ``None`` is treated like ``""`` (no output).

    Raises:
        IOError: if the video cannot be opened.

    Press ESC in the preview window to stop early. The shared YOLO session is
    closed when the loop ends, so the module-level ``yolo`` cannot be reused
    afterwards.
    """
    vid = cv2.VideoCapture(video_path)
    if not vid.isOpened():
        raise IOError("Couldn't open video")
    out = None
    try:
        if output_path:
            video_FourCC = cv2.VideoWriter_fourcc(*'XVID')
            video_fps = vid.get(cv2.CAP_PROP_FPS)
            video_size = (int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                          int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size)
        accum_time = 0
        curr_fps = 0
        fps = "FPS: ??"
        prev_time = timer()
        cv2.namedWindow("LiveFeed", cv2.WINDOW_NORMAL)  # creating it once is enough
        while True:
            return_value, frame = vid.read()
            if not return_value:
                break
            # NOTE(review): OpenCV frames are BGR and are passed to YOLO unchanged;
            # if the model expects RGB input, convert with cv2.cvtColor first.
            image = yolo.detect_image(Image.fromarray(frame), isWeb=False)
            # np.array makes a writable copy; np.asarray on a PIL image can be
            # read-only, and cv2.putText refuses to draw on read-only arrays.
            LiveFeed = np.array(image)
            curr_time = timer()
            exec_time = curr_time - prev_time
            prev_time = curr_time
            accum_time += exec_time
            curr_fps += 1
            if accum_time > 1:  # refresh the FPS label roughly once per second
                accum_time -= 1
                fps = "FPS: " + str(curr_fps)
                curr_fps = 0
            cv2.putText(LiveFeed, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.50, color=(255, 0, 0), thickness=2)
            cv2.imshow("LiveFeed", LiveFeed)
            if out is not None:
                out.write(LiveFeed)
            if cv2.waitKey(10) == 27:  # ESC
                break
    finally:
        # Release explicitly rather than relying on garbage collection so the
        # output file is finalized and the input source is freed.
        vid.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
    yolo.close_session()


def detect_image(image, isWeb):
    """Annotate *image* using the module-level detector.

    Any ``isWeb`` value other than the ``False`` singleton selects web mode,
    in which the detector returns ``(image, boxes, scores, class_names)``
    instead of only the annotated image.
    """
    web_mode = isWeb is not False
    return yolo.detect_image(image, web_mode)


def detect_camera(cam):
    """
    Send each frame from a camera/stream to YOLO and show the annotated result.

    Args:
        cam: anything ``cv2.VideoCapture`` accepts, e.g. a device index or a
            stream URL (``rtsp`` passes an RTSP address here).

    Raises:
        IOError: if the source cannot be opened.

    Press ESC in the preview window to stop. The shared YOLO session is closed
    when the loop ends, so the module-level ``yolo`` cannot be reused afterwards.
    """
    vid = cv2.VideoCapture(cam)
    if not vid.isOpened():
        raise IOError("Couldn't open webcam!Please Check cable connection or driver installation!")
    accum_time = 0
    curr_fps = 0
    fps = "FPS: ??"
    prev_time = timer()
    try:
        cv2.namedWindow("LiveFeed", cv2.WINDOW_NORMAL)  # creating it once is enough
        while True:
            return_value, frame = vid.read()
            if not return_value:
                break
            # NOTE(review): OpenCV frames are BGR and are passed to YOLO unchanged;
            # if the model expects RGB input, convert with cv2.cvtColor first.
            image = yolo.detect_image(Image.fromarray(frame), isWeb=False)
            # np.array makes a writable copy; np.asarray on a PIL image can be
            # read-only, and cv2.putText refuses to draw on read-only arrays.
            LiveFeed = np.array(image)
            curr_time = timer()
            exec_time = curr_time - prev_time
            prev_time = curr_time
            accum_time += exec_time
            curr_fps += 1
            if accum_time > 1:  # refresh the FPS label roughly once per second
                accum_time -= 1
                fps = "FPS: " + str(curr_fps)
                curr_fps = 0
            cv2.putText(LiveFeed, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.50, color=(255, 0, 0), thickness=2)
            cv2.imshow("LiveFeed", LiveFeed)
            if cv2.waitKey(10) == 27:  # ESC
                break
    finally:
        # Free the camera/stream explicitly instead of waiting for garbage collection.
        vid.release()
        cv2.destroyAllWindows()
    yolo.close_session()


def detect_img():
    """Interactively prompt for image paths, run detection and display each result.

    Loops until the user presses Ctrl+C (or sends EOF). The shared YOLO
    session is then closed, which the original infinite loop never reached.
    """
    try:
        while True:
            img = input('Input image filename(Press Ctrl+C to stop):')
            try:
                # convert('RGB') forces the lazy file to load (surfacing decode
                # errors here) and normalises RGBA/greyscale/palette images to
                # the 3-channel layout the network input uses.
                image = Image.open(img).convert('RGB')
            except OSError:  # missing file, directory, unidentified/truncated image
                print('Open Error!Check if the image exists and try again!')
                continue
            r_image = yolo.detect_image(image, isWeb=False)
            r_image.show()
    except (KeyboardInterrupt, EOFError):
        pass  # documented way to leave the prompt
    finally:
        yolo.close_session()


class YOLO(object):  # YOLO main class
    """Keras YOLOv3 detector: loads the model once and draws detections on PIL images.

    Defaults come from ``_defaults`` (file paths are taken from ``settings``).
    Any of them can be overridden with keyword arguments to the constructor.
    """

    _defaults = dict(model_path=MODEL_PATH, anchors_path=ANCHORS_PATH,
                     classes_path=CLASSES_PATH, score=0.3, iou=0.45, model_image_size=(416, 416),
                     gpu_num=1)

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of attribute *n*.

        Unknown names return an explanatory string instead of raising.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults and overrides, read classes/anchors and build the inference graph."""
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def _get_class(self):
        """Return the class names listed one per line in ``classes_path``, stripped."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    def _get_anchors(self):
        """Parse the comma-separated first line of ``anchors_path`` into an (N, 2) array."""
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)

    def generate(self):
        """Load the Keras model and build the filtered box/score/class tensors.

        This also prepares ``self.colors`` (one RGB tuple per class) and the
        ``self.input_image_shape`` placeholder that ``detect_image`` feeds.

        Returns:
            tuple: ``(boxes, scores, classes)`` tensors produced by ``yolo_eval``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Load model, or construct model and load weights.
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)  # type: int
        is_tiny_version = num_anchors == 6  # default setting
        try:
            self.yolo_model = load_model(model_path, compile=False)
        except Exception:
            # Not a full saved model: rebuild the architecture and load weights only.
            if is_tiny_version:
                self.yolo_model = tiny_yolo_body(Input(shape=(None, None, 3)), num_anchors // 2, num_classes)
            else:
                self.yolo_model = yolo_body(Input(shape=(None, None, 3)), num_anchors // 3, num_classes)
            # Use the same user-expanded path as load_model above.
            self.yolo_model.load_weights(model_path)  # make sure model, anchors and classes match
        else:
            assert self.yolo_model.layers[-1].output_shape[-1] == \
                   num_anchors / len(self.yolo_model.output) * (num_classes + 5), \
                'Mismatch between model and given anchor and class sizes'

        print('{} model, anchors, and classes loaded,please wait...'.format(model_path))

        # Generate colors for drawing bounding boxes: evenly spaced hues.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        self.colors = [tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
                       for hsv in hsv_tuples]
        # Fixed seed for consistent colors across runs, shuffled to decorrelate
        # adjacent classes. A private RandomState yields the same permutation as
        # seeding the global generator, without clobbering global RNG state.
        np.random.RandomState(10101).shuffle(self.colors)

        # Generate output tensor targets for filtered bounding boxes.
        self.input_image_shape = K.placeholder(shape=(2,))
        if self.gpu_num >= 2:
            self.yolo_model = multi_gpu_model(self.yolo_model, gpus=self.gpu_num)
        boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                                           len(self.class_names), self.input_image_shape,
                                           score_threshold=self.score, iou_threshold=self.iou)
        return boxes, scores, classes

    @staticmethod
    def _text_size(draw, text, font):
        """Return ``(width, height)`` of *text* rendered with *font*.

        ``ImageDraw.textsize`` was removed in Pillow 10, so prefer ``textbbox``
        (available since Pillow 8.0). Measured from the origin, its right/bottom
        edges include the font offset just like ``textsize`` did. Older Pillow
        falls back to ``textsize``.
        """
        if hasattr(draw, 'textbbox'):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font)

    def detect_image(self, image, isWeb):
        """Run detection on a PIL image and draw labelled boxes onto it in place.

        Args:
            image: PIL image. It is annotated in place.
            isWeb: when exactly ``True``, the raw detections are returned too.

        Returns:
            The annotated image, or ``(image, out_boxes, out_scores, class_names)``
            when ``isWeb is True``.
        """
        start = timer()

        if self.model_image_size != (None, None):
            assert self.model_image_size[0] % 32 == 0, 'Multiples of 32 required'
            assert self.model_image_size[1] % 32 == 0, 'Multiples of 32 required'
            boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
        else:
            new_image_size = (image.width - (image.width % 32),
                              image.height - (image.height % 32))
            boxed_image = letterbox_image(image, new_image_size)
        image_data = np.array(boxed_image, dtype='float32')

        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.
        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]]
                # remove K.learning_phase(): 0 to avoid error.
            })

        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))

        # Font path is relative to the current working directory.
        font = ImageFont.truetype(font='font/Sansus-Webissimo-Regular.otf',
                                  size=int(np.floor(3e-2 * image.size[1] + 0.5)))
        # At least 1 px: the formula yields 0 when width + height < 300,
        # which would draw no box outline at all.
        thickness = max(1, (image.size[0] + image.size[1]) // 300)

        draw = ImageDraw.Draw(image)
        for i, c in reversed(list(enumerate(out_classes))):
            predicted_class = self.class_names[c]
            box = out_boxes[i]
            score = out_scores[i]

            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._text_size(draw, label, font)

            # Round box edges to pixels and clip them to the image.
            top, left, bottom, right = box
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
            right = min(image.size[0], np.floor(right + 0.5).astype('int32'))
            print(label, (left, top), (right, bottom))

            # Place the label above the box when it fits, otherwise just inside it.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Emulate a thick outline with nested 1-px rectangles.
            for offset in range(thickness):
                draw.rectangle(
                    [left + offset, top + offset, right - offset, bottom - offset],
                    outline=self.colors[c])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[c])
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw

        end = timer()
        print(end - start)
        if isWeb is True:
            return (image, out_boxes, out_scores, [self.class_names[k] for k in out_classes])
        else:
            return image

    def close_session(self):
        """Close the Keras/TF session; the detector is unusable afterwards."""
        self.sess.close()


class webCamConnect:
    """TCP client that pulls an encoded image stream from a remote camera server,
    runs the module-level YOLO detector on each frame and displays the result.

    Wire protocol as implemented here:
    - After connecting, the client sends one header ``(src, width, height)``.
    - The server then repeatedly sends a header whose first field is the byte
      length of the next encoded image, followed by exactly that many bytes.
    """

    # Header format shared with the server. Native sizes are kept for
    # compatibility with the existing server. NOTE(review): "l" is 4 bytes on
    # Windows but 8 on 64-bit Linux/macOS, so the header is 8 or 12 bytes
    # depending on platform -- confirm the server packs it the same way.
    _HEADER_FMT = "lhh"
    _HEADER_SIZE = struct.calcsize(_HEADER_FMT)

    def __init__(self, resolution=RESOLUTION, remoteAddress=(REMOTE_IP,
                                                             REMOTE_PORT), windowName="Online Stream"):
        self.remoteAddress = remoteAddress
        self.resolution = resolution
        self.name = windowName
        self.mutex = threading.Lock()  # guards self.buf / self.image shared with the save thread
        self.src = 911 + FPS  # verification number 911 need to coordinate with server
        self.interval = 0  # seconds between local snapshots; 0 disables saving
        self.path = os.getcwd()
        self.img_quality = ONLINE_STREAM_QUALITY
        self.buf = b''  # last encoded image received
        self.image = None  # last decoded frame; None until the first one arrives

    def _setSocket(self):
        """Create the TCP socket (with SO_REUSEADDR) used for the stream."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    def connect(self):
        """Open the connection to ``self.remoteAddress``; socket errors propagate."""
        self._setSocket()
        self.socket.connect(self.remoteAddress)

    def _recvExact(self, size):
        """Return exactly *size* bytes from the socket.

        ``recv`` on a stream socket may return fewer bytes than requested, so
        keep reading until the full amount has arrived.

        Raises:
            ConnectionError: if the peer closes the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:  # empty read == orderly shutdown by the peer
                raise ConnectionError("connection closed by remote server")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _receiveFrame(self, size):
        """Read one encoded image of *size* bytes, decode it and publish it.

        The decoded frame is stored in ``self.image`` under the mutex.
        Returns the decoded array, or None if the bytes do not decode to an image.
        """
        with self.mutex:
            self.buf = self._recvExact(size)
            try:
                frame = cv2.imdecode(np.frombuffer(self.buf, dtype='uint8'), 1)  # 1 = colour
            except cv2.error:
                frame = None
            if frame is not None:
                self.image = frame
        return frame

    def _shutdown(self):
        """Close the socket, the preview windows and the shared YOLO session."""
        self.socket.close()
        cv2.destroyAllWindows()
        yolo.close_session()

    def _processImage(self):
        """Receive, detect and display frames until ESC is pressed or the link drops.

        Runs on the thread started by getData. On exit the socket, the windows
        and the shared YOLO session are closed.
        """
        self.socket.sendall(struct.pack(self._HEADER_FMT, self.src,
                                        self.resolution[0], self.resolution[1]))
        accum_time = 0
        curr_fps = 0
        fps = "FPS: ??"
        prev_time = timer()
        cv2.namedWindow(self.name, cv2.WINDOW_NORMAL)  # creating it once is enough
        while True:
            try:
                header = self._recvExact(self._HEADER_SIZE)
                bufSize = struct.unpack(self._HEADER_FMT, header)[0]
                # Decode only after the *whole* encoded image has arrived.
                frame = self._receiveFrame(bufSize) if bufSize else None
            except OSError:  # reset/closed connection (ConnectionError is an OSError)
                print("Receive Failure!Check Connection!")
                self._shutdown()
                print("Abort Connection!")
                break
            if not bufSize:
                continue  # empty announcement: wait for the next header
            if frame is None:
                print("Receive Failure!Check Connection!")  # bytes did not decode to an image
            else:
                try:
                    Feed = Image.fromarray(frame)  # YOLOv3 needs image format
                    # np.array gives a writable copy; cv2.putText rejects read-only arrays.
                    LiveFeed = np.array(yolo.detect_image(Feed, isWeb=False))
                    curr_time = timer()
                    exec_time = curr_time - prev_time
                    prev_time = curr_time
                    accum_time += exec_time
                    curr_fps += 1
                    if accum_time > 1:  # refresh the FPS label roughly once per second
                        accum_time -= 1
                        fps = "FPS: " + str(curr_fps)
                        curr_fps = 0
                    cv2.putText(LiveFeed, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.50,
                                color=(255, 0, 0), thickness=2)
                    cv2.imshow(self.name, LiveFeed)  # show image frame by frame
                except Exception as err:  # best effort: one bad frame must not stop the stream
                    print("Detection Failure! %s" % err)
            if cv2.waitKey(10) == 27:  # ESC
                self._shutdown()
                print("Abort Connection!")
                break

    def getData(self, interval):
        """Start the display thread and, if *interval* is non-zero, a daemon snapshot thread.

        ``self.interval`` is initialised to 0, so snapshot saving stays off
        unless it is changed before calling this.
        """
        showThread = threading.Thread(target=self._processImage)
        showThread.start()
        if interval != 0:
            saveThread = threading.Thread(target=self._savePicToLocal, args=(interval,), daemon=True)
            saveThread.start()

    def setWindowName(self, name):
        """Set the title of the preview window."""
        self.name = name

    def setRemoteAddress(self, remoteAddress):
        """Set the ``(host, port)`` used by the next connect()."""
        self.remoteAddress = remoteAddress

    def _savePicToLocal(self, interval):
        """Every *interval* seconds, save the latest raw frame as ./savePic/<timestamp>.jpg.

        Best effort: nothing is saved before the first frame arrives, and
        failures are reported without stopping the thread.
        """
        while True:
            with self.mutex:
                # Frames are replaced, never modified in place, so holding a
                # reference lets the disk write happen outside the lock.
                image = getattr(self, "image", None)
            if image is not None:
                try:
                    folder = os.path.join(os.getcwd(), "savePic")
                    os.makedirs(folder, exist_ok=True)
                    cv2.imwrite(os.path.join(folder, time.strftime("%Y%m%d-%H%M%S") + ".jpg"), image)
                except Exception as err:  # keep the saver alive on I/O or encode errors
                    print("Save Failure! %s" % err)
            time.sleep(interval)


# Module-level detector shared by the functions and the webCamConnect class in this file.
# Constructing it reads the classes/anchors files and loads the model at import
# time, so importing this module is slow and requires those files to exist.
yolo = YOLO()  # initializing yolo
